#!/usr/bin/env python
import os
import pandas as pd
import numpy as np

# Input prestack tables, one per survey. main() aborts if either of the two
# stack tables (kids_path, dr5_path) does not exist.
kids_path = "data/prestacked_stacks.csv"
# Per-stack metadata tables. main() only merges metadata when both the KiDS
# and DR5 meta files exist; otherwise it prints a warning and skips that step.
kids_meta_path = "data/prestacked_meta.csv"
dr5_path = "data/prestacked_stacks_dr5.csv"
dr5_meta_path = "data/prestacked_meta_dr5.csv"

# Output locations for the combined stack table and the combined metadata.
out_stacks = "data/prestacked_stacks_combined.csv"
out_meta = "data/prestacked_meta_combined.csv"

def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` if it is not there yet."""
    parent = os.path.dirname(path)
    # An empty dirname means "current directory"; os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _weighted_combine(stacks, keys):
    """Collapse ``stacks`` to one row per ``keys`` group via a weighted mean.

    For every group the result holds
    ``gamma_t = sum(weight * gamma_t) / sum(weight)`` and
    ``weight = sum(weight)``.  Rows whose ``weight`` or ``gamma_t`` is not
    finite contribute nothing to either sum, so an empty entry in one input
    (e.g. ``weight=0, gamma_t=NaN``) cannot turn the whole group into NaN.
    Groups whose total weight is not positive get ``gamma_t = NaN``.

    Args:
        stacks: Frame containing the ``keys`` columns plus ``gamma_t`` and
            ``weight``.
        keys: Column names that identify a group.

    Returns:
        A new frame with columns ``keys + ["gamma_t", "weight"]``, one row
        per group, ordered by the group keys.
    """
    weight = stacks["weight"].to_numpy(dtype=float)
    gamma_t = stacks["gamma_t"].to_numpy(dtype=float)

    # Zero out unusable rows *before* multiplying: 0 * NaN is NaN, which is
    # exactly how a single empty bin used to poison the combined value.
    usable = np.isfinite(weight) & np.isfinite(gamma_t)
    eff_weight = np.where(usable, weight, 0.0)
    eff_gamma = np.where(usable, gamma_t, 0.0)

    work = stacks[keys].copy()
    work["weight"] = eff_weight
    work["_w_gamma"] = eff_weight * eff_gamma

    # Plain group sums instead of groupby.apply: the key columns keep their
    # dtypes (integer ids stay integers) and the applied-function-reads-the-
    # grouping-columns pattern, deprecated in pandas 2.2, is avoided.
    sums = work.groupby(keys, as_index=False)[["weight", "_w_gamma"]].sum()
    sum_w = sums["weight"].to_numpy(dtype=float)
    sum_wg = sums.pop("_w_gamma").to_numpy(dtype=float)

    gamma = np.full(len(sums), np.nan)
    has_weight = sum_w > 0
    gamma[has_weight] = sum_wg[has_weight] / sum_w[has_weight]
    sums["gamma_t"] = gamma
    return sums[keys + ["gamma_t", "weight"]]


def main():
    """Combine the KiDS and DR5 prestack tables into one weighted table.

    Reads ``kids_path`` and ``dr5_path``, checks that both contain the
    required columns, and writes to ``out_stacks`` one row per
    ``(stack_id, R_G_bin, Mstar_bin, b)`` with the weight-averaged
    ``gamma_t`` and the summed ``weight``.

    If both ``kids_meta_path`` and ``dr5_meta_path`` exist, the metadata is
    merged per ``stack_id`` (``n_lenses`` summed, ``R_G_mean_kpc`` averaged)
    and written to ``out_meta``; otherwise a warning is printed.

    Raises:
        SystemExit: If either prestack file is missing, or a required column
            is absent from a prestack or meta table.
    """
    if not (os.path.exists(kids_path) and os.path.exists(dr5_path)):
        raise SystemExit("Missing KiDS or DR5 prestacks.")

    kids = pd.read_csv(kids_path)
    dr5 = pd.read_csv(dr5_path)

    # Ensure same columns
    keys = ["stack_id", "R_G_bin", "Mstar_bin", "b"]
    cols = keys + ["gamma_t", "weight"]
    for df, name in [(kids, "KiDS"), (dr5, "DR5")]:
        for c in cols:
            if c not in df.columns:
                raise SystemExit(f"Missing column '{c}' in {name} prestacks.")

    all_stacks = pd.concat([kids[cols], dr5[cols]], ignore_index=True)

    # Weighted combine: gamma_comb = sum(w*gamma) / sum(w)
    combined = _weighted_combine(all_stacks, keys)

    # Create whatever directory the configured output path points at rather
    # than assuming it is always "data".
    _ensure_parent_dir(out_stacks)
    combined.to_csv(out_stacks, index=False)
    print(f"[info] Wrote combined stacks to {out_stacks} ({len(combined)} rows).")

    # Meta: just add n_lenses and average R_G_mean_kpc
    if os.path.exists(kids_meta_path) and os.path.exists(dr5_meta_path):
        kids_meta = pd.read_csv(kids_meta_path)
        dr5_meta = pd.read_csv(dr5_meta_path)

        # Same explicit check as for the stack tables, so a malformed meta
        # file produces a readable message instead of a KeyError traceback.
        meta_cols = ["stack_id", "n_lenses", "R_G_mean_kpc"]
        for df, name in [(kids_meta, "KiDS"), (dr5_meta, "DR5")]:
            for c in meta_cols:
                if c not in df.columns:
                    raise SystemExit(f"Missing column '{c}' in {name} meta.")

        all_meta = pd.concat([kids_meta, dr5_meta], ignore_index=True)
        # NOTE(review): R_G_mean_kpc is the plain mean of the per-file values,
        # not weighted by n_lenses -- confirm this is the intended average.
        meta_comb = all_meta.groupby("stack_id", as_index=False).agg(
            n_lenses=("n_lenses", "sum"),
            R_G_mean_kpc=("R_G_mean_kpc", "mean"),
        )
        _ensure_parent_dir(out_meta)
        meta_comb.to_csv(out_meta, index=False)
        print(f"[info] Wrote combined meta to {out_meta} ({len(meta_comb)} stacks).")
    else:
        print("[warn] Missing meta files; meta not combined.")

# Run the combine step only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
